Tutorial ML 2
A common task in BCI research is to test a machine learning model (MLM) on a large amount of real data. This tutorial uses the FII BCI Corpus as an example.
If you have not downloaded the corpus yet, do so with the downloadDB function before running this tutorial.
The tutorial shows how to
- Select databases and sessions from the FII BCI Corpus according to:
  - BCI paradigm (Motor Imagery or P300)
  - availability of specific classes
  - minimum number of trials per class
- Run a cross-validation on all selected sessions of all selected databases and store the resulting balanced accuracies
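Taken together, the three selection criteria above act as a filter over session metadata. The sketch below only illustrates that filtering logic. It is not how selectDB is implemented. The SessionInfo type, the select_sessions function and the session values are all hypothetical.

```julia
# Hypothetical session metadata (not Eegle's data structure): paradigm and number of trials per class
struct SessionInfo
    file::String
    paradigm::Symbol
    trials::Dict{String,Int}
end

# Keep the sessions of the requested paradigm that contain every requested class
# with at least `minTrials` trials each
function select_sessions(sessions::Vector{SessionInfo}, paradigm::Symbol;
                         classes::Vector{String}, minTrials::Int)
    filter(sessions) do s
        s.paradigm == paradigm && all(c -> get(s.trials, c, 0) >= minTrials, classes)
    end
end

sessions = [
    SessionInfo("s1", :MI, Dict("feet" => 40, "right_hand" => 40, "left_hand" => 40)),
    SessionInfo("s2", :MI, Dict("feet" => 20, "right_hand" => 40)),      # too few "feet" trials
    SessionInfo("s3", :MI, Dict("left_hand" => 50, "right_hand" => 50)), # no "feet" class
    SessionInfo("s4", :P300, Dict("target" => 60, "nontarget" => 300)),  # other paradigm
]

kept = select_sessions(sessions, :MI; classes = ["feet", "right_hand"], minTrials = 30)
println([s.file for s in kept])   # prints ["s1"]
```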
As the MLM, we use the minimum distance to mean (MDM) Riemannian classifier with the affine-invariant (Fisher-Rao) metric (Barachant et al., 2012; Congedo et al., 2017). Covariance matrices are estimated with the linear shrinkage estimator of Ledoit and Wolf (2004). These state-of-the-art settings are the defaults in Eegle.
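The pipeline has two ingredients. The first is the Ledoit-Wolf shrinkage estimate of each trial's covariance. The second is the MDM rule, which assigns a trial to the class whose Fisher-Rao mean is nearest in the affine-invariant metric. The minimal sketch below implements both from their textbook definitions using only the standard library. It is not Eegle's implementation. Every function name in it (ledoitwolf, distance, fisher_rao_mean, fit_mdm, predict_mdm) is hypothetical. The sketch also assumes that each trial is a T×N matrix (samples × channels), and it runs on synthetic data.

```julia
using LinearAlgebra, Statistics, Random

# Apply a scalar function to a symmetric matrix through its eigendecomposition
function spdfun(f, S::AbstractMatrix)
    λ, V = eigen(Symmetric(Matrix(S)))
    M = V * Diagonal(f.(λ)) * V'
    return (M + M') / 2                           # remove rounding asymmetry
end

# Ledoit-Wolf (2004) linear shrinkage of the sample covariance toward μI.
# Assumption: X is T×N (time samples × channels).
function ledoitwolf(X::AbstractMatrix{<:Real})
    n, p = size(X)
    Xc = X .- mean(X; dims = 1)                   # center each channel
    S = (Xc' * Xc) / n                            # sample covariance (1/n normalization)
    μ = tr(S) / p
    F = μ * Matrix{Float64}(I, p, p)              # shrinkage target μI
    d2 = sum(abs2, S - F) / p                     # ‖S - F‖², normalized Frobenius norm
    b2bar = sum(sum(abs2, x * x' - S) for x in eachrow(Xc)) / (n^2 * p)
    ρ = d2 > 0 ? min(b2bar, d2) / d2 : 0.0        # shrinkage intensity in [0, 1]
    return ρ * F + (1 - ρ) * S, ρ
end

# Affine-invariant (Fisher-Rao) distance: sqrt(Σ log²λᵢ), λᵢ the eigenvalues of A⁻¹B
function distance(A, B)
    L = cholesky(Symmetric(A)).L
    M = L \ Matrix(B) / L'                        # L⁻¹BL⁻ᵀ has the same eigenvalues as A⁻¹B
    return sqrt(sum(abs2, log.(eigvals(Symmetric((M + M') / 2)))))
end

# Fisher-Rao (Karcher) mean by the usual fixed-point iteration
function fisher_rao_mean(Cs; maxiter = 50, tol = 1e-10)
    G = sum(Cs) / length(Cs)                      # start from the arithmetic mean
    for _ in 1:maxiter
        Gh, Gih = spdfun(sqrt, G), spdfun(x -> 1 / sqrt(x), G)
        H = sum(spdfun(log, Gih * C * Gih) for C in Cs) / length(Cs)  # mean in the tangent space at G
        G = Gh * spdfun(exp, H) * Gh              # map back onto the manifold
        G = (G + G') / 2
        norm(H) < tol && break
    end
    return G
end

# MDM: one mean per class; a trial is assigned to the class with the nearest mean
fit_mdm(covs, labels) = Dict(c => fisher_rao_mean(covs[labels .== c]) for c in unique(labels))
predict_mdm(means, C) = argmin(c -> distance(C, means[c]), collect(keys(means)))

# Toy usage on synthetic trials (hypothetical data, not from the corpus)
rng = MersenneTwister(1)
T, N = 200, 4
trial(scale) = randn(rng, T, N) .* scale'         # scale the variance of each channel
trials = [[trial(ones(N)) for _ in 1:20]; [trial([3.0, 1, 1, 1]) for _ in 1:20]]
labels = [fill("a", 20); fill("b", 20)]
covs = [first(ledoitwolf(X)) for X in trials]
means = fit_mdm(covs, labels)
acc = mean(predict_mdm(means, C) == y for (C, y) in zip(covs, labels))
```

The metric is affine-invariant: if the same invertible linear transformation is applied to the channels of every trial, all distances stay the same, so the MDM decisions do not change.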
For each session, an 8-fold stratified cross-validation is run. As the computations proceed, a summary is printed for each session. It reports the mean and standard deviation of the balanced accuracy across the folds and the p-value of the cross-validation test statistic.
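Two ingredients of this summary can be sketched in plain Julia. The first is stratified assignment of trials to k folds, so that each fold keeps roughly the class proportions of the whole session. The second is the balanced accuracy, that is, the mean of the per-class recalls. This is an illustrative sketch, not crval's code. The names stratified_folds and balanced_accuracy are hypothetical. The toy "classifier" skips the training step on the other folds entirely. The example shows why balanced accuracy is reported: a classifier that always predicts the majority class scores at chance level, even on unbalanced data.

```julia
using Random, Statistics

# Assign each trial to one of k folds so that every class is spread evenly across folds
function stratified_folds(labels::AbstractVector, k::Int; rng = Random.default_rng())
    fold = zeros(Int, length(labels))
    for c in unique(labels)
        idx = shuffle(rng, findall(==(c), labels))  # trials of class c, in random order
        for (j, i) in enumerate(idx)
            fold[i] = mod1(j, k)                    # deal them out to folds 1, 2, ..., k, 1, 2, ...
        end
    end
    return fold
end

# Balanced accuracy: mean over classes of the per-class recall
balanced_accuracy(ytrue, ypred) =
    mean(mean(ypred[ytrue .== c] .== c) for c in unique(ytrue))

labels = [fill("feet", 60); fill("right_hand", 20)]        # unbalanced toy labels
folds = stratified_folds(labels, 8; rng = MersenneTwister(0))
# a trivial "classifier" that always answers "feet", scored on each test fold
accs = [balanced_accuracy(labels[folds .== f], fill("feet", count(==(f), folds))) for f in 1:8]
println(mean(accs), " ± ", std(accs))   # prints 0.5 ± 0.0 (chance level)
println(mean(labels .== "feet"))         # plain accuracy would be 0.75
```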
Select all motor imagery databases in the FII BCI Corpus featuring the "feet" and "right_hand" classes. Within these databases, select the sessions featuring at least 30 trials for each of these classes (see selectDB).

```julia
using Eegle

classes = ["feet", "right_hand"];
DBs = selectDB(:MI; classes, minTrials = 30);
```

Create memory to store all accuracies:

```julia
MIacc = [zeros(length(DB.files)) for DB ∈ DBs];
```

Perform the cross-validation on all selected sessions of all selected databases:
```julia
for (db, DB) ∈ enumerate(DBs), (f, file) ∈ enumerate(DB.files)
    # perform cross-validation (using Eegle)
    cv = crval(file; upperLimit = 1.2, bandPass = (8, 32), classes)
    # store accuracy
    MIacc[db][f] = cv.avgAcc
    # print a summary of the cv results
    println("\nDatabase ", DB.dbName, ", File ", f,
            ": mean(sd) balanced accuracy ", round(cv.avgAcc*100, digits=2),
            "% (± ", round(cv.stdAcc*100, digits=2), "%); ",
            "p-value ", round(cv.p; digits = 4))
end
```

Show all MI accuracies:

```julia
[round.(db; digits=2) for db ∈ MIacc]
```

Perform the cross-validation on all available P300 databases and on all sessions featuring at least 25 trials for both the target and non-target classes. For P300 there is no need to specify these two classes, as they are the default:
```julia
# select the P300 databases first, so that P300acc is sized on them
DBs = selectDB(:P300; minTrials = 25);
P300acc = [zeros(length(DB.files)) for DB ∈ DBs];
for (db, DB) ∈ enumerate(DBs), (f, file) ∈ enumerate(DB.files)
    # perform cross-validation (using Eegle)
    cv = crval(file; upperLimit = 1.2, bandPass = (1, 24))
    # store accuracy
    P300acc[db][f] = cv.avgAcc
    # print a summary of the cv results
    println("\nDatabase ", DB.dbName, ", File ", f,
            ": mean(sd) balanced accuracy ", round(cv.avgAcc*100, digits=2),
            "% (± ", round(cv.stdAcc*100, digits=2), "%); ",
            "p-value ", round(cv.p; digits = 4))
end
```

For all available options for running cross-validations, see Eegle.BCI.crval.
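After both loops, MIacc and P300acc each hold one vector of balanced accuracies per database. Below is a minimal sketch of how such a vector of vectors could be summarized, both per database and over all sessions. The helper summarize is hypothetical, and the numbers are placeholders, not results obtained on the corpus.

```julia
using Statistics

# Summarize a vector of per-database accuracy vectors (the layout of MIacc and P300acc):
# returns (avg, sd, n) for each database and the mean over all sessions
function summarize(acc::Vector{<:AbstractVector{<:Real}})
    perDB = [(avg = mean(a), sd = length(a) > 1 ? std(a) : NaN, n = length(a)) for a in acc]
    overall = mean(reduce(vcat, acc))
    return perDB, overall
end

# Placeholder numbers for illustration only (not results from the corpus)
acc = [[0.70, 0.80, 0.90], [0.60, 0.80]]
perDB, overall = summarize(acc)
for (db, s) in enumerate(perDB)
    println("Database ", db, ": mean ", round(s.avg; digits = 3),
            ", sd ", round(s.sd; digits = 3), " (", s.n, " sessions)")
end
println("All sessions: ", round(overall; digits = 3))
```

With the tutorial's variables, the same call would be summarize(MIacc) or summarize(P300acc).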